1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17 
18 #include <map>
19 #include <memory>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/test.h"
33 
34 namespace op = xla::testing::opcode_matchers;
35 
36 namespace xla {
37 namespace {
38 
39 using ::testing::UnorderedElementsAre;
40 using ::testing::UnorderedElementsAreArray;
41 
42 class TuplePointsToAnalysisTest : public HloTestBase {
43  protected:
44   // Builds a module with the given entry computation and runs points to
45   // analysis.
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)46   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
47     BuildModule(std::move(computation));
48     RunAnalysis();
49   }
50 
BuildModule(std::unique_ptr<HloComputation> computation)51   void BuildModule(std::unique_ptr<HloComputation> computation) {
52     module_ = CreateNewUnverifiedModule();
53     module_->AddEntryComputation(std::move(computation));
54   }
55 
RunAnalysis()56   void RunAnalysis() {
57     CHECK_NOTNULL(module_.get());
58     points_to_analysis_ =
59         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
60   }
61 
62   // Returns the LogicalBuffer defined at the given instruction and
63   // index. CHECKs if no buffer is defined at that point.
GetBuffer(const HloInstruction * instruction,const ShapeIndex & index)64   const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
65                                        const ShapeIndex& index) {
66     const auto& pointed_to =
67         points_to_analysis_->GetPointsToSet(instruction).element(index);
68     CHECK_EQ(1, pointed_to.size());
69     CHECK_EQ(instruction, pointed_to[0]->instruction());
70     CHECK(index == pointed_to[0]->index());
71     return pointed_to[0];
72   }
73 
74   // Checks that the given points-to set contains exactly (unordered) the given
75   // LogicalBuffers.
ExpectHasBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<const LogicalBuffer * const> buffers)76   void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
77                         absl::Span<const LogicalBuffer* const> buffers) {
78     std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
79     EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
80   }
81 
82   // Checks that the given points-to set contains exactly (unordered) the
83   // top-level buffers of the given instructions.
ExpectHasTopLevelBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<HloInstruction * const> instructions)84   void ExpectHasTopLevelBuffers(
85       const PointsToSet::BufferList& points_to_set,
86       absl::Span<HloInstruction* const> instructions) {
87     PointsToSet::BufferList buffers;
88     for (auto instruction : instructions) {
89       buffers.push_back(GetBuffer(instruction, /*index=*/{}));
90     }
91     ExpectHasBuffers(points_to_set, buffers);
92   }
93 
94   // Overload which takes a set instead of a vector.
ExpectHasTopLevelBuffers(const PointsToSet::BufferSet & points_to_set,absl::Span<HloInstruction * const> instructions)95   void ExpectHasTopLevelBuffers(
96       const PointsToSet::BufferSet& points_to_set,
97       absl::Span<HloInstruction* const> instructions) {
98     ExpectHasTopLevelBuffers(
99         PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
100         instructions);
101   }
102 
103   // Checks that the buffer defined at the given instruction and index has
104   // aliases which are exactly (unordered) the given instruction/index pairs.
ExpectHasBufferAliases(const HloInstruction * instruction,const ShapeIndex & index,absl::Span<const std::pair<HloInstruction *,ShapeIndex>> expected)105   void ExpectHasBufferAliases(
106       const HloInstruction* instruction, const ShapeIndex& index,
107       absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
108     const LogicalBuffer* buffer =
109         points_to_analysis_->GetBufferDefinedAt(instruction, index)
110             .ValueOrDie();
111     std::vector<BufferAlias> expected_aliases;
112     for (auto& pair : expected) {
113       expected_aliases.push_back(BufferAlias(pair.first, pair.second));
114     }
115     EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer),
116                 UnorderedElementsAreArray(expected_aliases));
117   }
118 
119   std::unique_ptr<HloModule> module_;
120   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
121 };
122 
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
  // A flat tuple of two scalar constants. Each constant points only to its own
  // buffer. The tuple defines its top-level buffer and forwards the constants'
  // buffers at {0} and {1}.
  HloComputation::Builder builder(TestName());
  auto add_f32_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* constant1 = add_f32_constant(1.0f);
  HloInstruction* constant2 = add_f32_constant(2.0f);
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& constant1_set =
      points_to_analysis_->GetPointsToSet(constant1);
  const PointsToSet& constant2_set =
      points_to_analysis_->GetPointsToSet(constant2);
  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);

  // Scalar constants: one buffer each, and no tuple sources.
  EXPECT_EQ(1, constant1_set.size());
  ExpectHasTopLevelBuffers(constant1_set.element({}), {constant1});
  EXPECT_TRUE(constant1_set.tuple_sources({}).empty());

  EXPECT_EQ(1, constant2_set.size());
  ExpectHasTopLevelBuffers(constant2_set.element({}), {constant2});
  EXPECT_TRUE(constant2_set.tuple_sources({}).empty());

  // The tuple: three distinct buffers, unambiguous, and it is its own source.
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_TRUE(tuple_set.IsDistinct());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {constant2});

  // Membership queries, both index-specific and anywhere in the set.
  const LogicalBuffer& buffer1 = *GetBuffer(constant1, {});
  const LogicalBuffer& buffer2 = *GetBuffer(constant2, {});
  EXPECT_TRUE(tuple_set.ContainsBufferAtIndex(buffer1, {0}));
  EXPECT_TRUE(tuple_set.ContainsBufferAtIndex(buffer2, {1}));
  EXPECT_FALSE(tuple_set.ContainsBufferAtIndex(buffer2, {0}));
  EXPECT_TRUE(tuple_set.ContainsBuffer(buffer1));
  EXPECT_TRUE(tuple_set.ContainsBuffer(buffer2));
}
172 
TEST_F(TuplePointsToAnalysisTest,NestedTuple)173 TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
174   // Create a (nested) tuple containing an inner tuple. The points-to set of the
175   // outer tuple should contain all elements of the points-to set of the inner
176   // tuple.
177   auto builder = HloComputation::Builder(TestName());
178   auto constant1 = builder.AddInstruction(
179       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
180   auto constant2 = builder.AddInstruction(
181       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
182   auto inner_tuple = builder.AddInstruction(
183       HloInstruction::CreateTuple({constant1, constant2}));
184 
185   auto constant3 = builder.AddInstruction(
186       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
187   auto tuple = builder.AddInstruction(
188       HloInstruction::CreateTuple({inner_tuple, constant3}));
189 
190   BuildModuleAndRunAnalysis(builder.Build());
191   ExpectHasTopLevelBuffers(
192       points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
193   ExpectHasTopLevelBuffers(
194       points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
195   ExpectHasTopLevelBuffers(
196       points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3});
197 
198   EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(inner_tuple).size());
199   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(inner_tuple).IsAmbiguous());
200   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(inner_tuple).IsDistinct());
201   ExpectHasTopLevelBuffers(
202       points_to_analysis_->GetPointsToSet(inner_tuple).CreateFlattenedSet(),
203       {constant1, constant2, inner_tuple});
204   ExpectHasTopLevelBuffers(
205       points_to_analysis_->GetPointsToSet(inner_tuple).element({}),
206       {inner_tuple});
207   EXPECT_THAT(
208       points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}),
209       UnorderedElementsAre(inner_tuple));
210 
211   EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size());
212   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
213   ExpectHasTopLevelBuffers(
214       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
215       {constant1, constant2, constant3, inner_tuple, tuple});
216 
217   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
218               UnorderedElementsAre(tuple));
219   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}),
220               UnorderedElementsAre(inner_tuple));
221   EXPECT_TRUE(
222       points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty());
223 
224   ExpectHasTopLevelBuffers(
225       points_to_analysis_->GetPointsToSet(tuple).element({0}), {inner_tuple});
226   ExpectHasTopLevelBuffers(
227       points_to_analysis_->GetPointsToSet(tuple).element({0, 0}), {constant1});
228   ExpectHasTopLevelBuffers(
229       points_to_analysis_->GetPointsToSet(tuple).element({0, 1}), {constant2});
230   ExpectHasTopLevelBuffers(
231       points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant3});
232 }
233 
TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
  // Build ((constant1, constant2), constant3), then extract the inner tuple
  // with GetTupleElement. The GTE's points-to set should equal the inner
  // tuple's: same buffers and the same tuple source.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* outer_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, constant3}));
  HloInstruction* gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      inner_tuple->shape(), outer_tuple, 0));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& gte_set = points_to_analysis_->GetPointsToSet(gte);
  EXPECT_EQ(3, gte_set.size());
  EXPECT_FALSE(gte_set.IsAmbiguous());
  EXPECT_TRUE(gte_set.IsDistinct());
  ExpectHasTopLevelBuffers(gte_set.CreateFlattenedSet(),
                           {constant1, constant2, inner_tuple});
  ExpectHasTopLevelBuffers(gte_set.element({}), {inner_tuple});
  EXPECT_THAT(gte_set.tuple_sources({}), UnorderedElementsAre(inner_tuple));
}
267 
TEST_F(TuplePointsToAnalysisTest, AddDependency) {
  // The AddDependency should point to exactly the buffer of its data operand
  // (the constant). The token operand contributes nothing to its points-to
  // set.
  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* add_dependency = builder.AddInstruction(
      HloInstruction::CreateAddDependency(data, token));
  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& result =
      points_to_analysis_->GetPointsToSet(add_dependency);
  EXPECT_EQ(1, result.size());
  EXPECT_FALSE(result.IsAmbiguous());
  EXPECT_TRUE(result.IsDistinct());
  ExpectHasTopLevelBuffers(result.CreateFlattenedSet(), {data});
}
283 
TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
  // A tuple listing the same constant three times. Its points-to set holds
  // only two buffers (the tuple and the shared constant). Because that
  // constant appears at several indices, the set is expected not to be
  // distinct.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant, constant, constant}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(2, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_FALSE(tuple_set.IsDistinct());
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(), {constant, tuple});
}
303 
TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
  // Create a copy (HloOpcode::kCopy) of a tuple. The copy defines its own
  // top-level buffer, but its elements point to the same buffers as the
  // elements of the original tuple (constant1 and constant2).
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(copy).IsAmbiguous());
  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(copy).IsDistinct());
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
      {constant1, constant2, tuple});
  // Top level: the copy's own buffer, not the tuple's.
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(copy).element({}), {copy});
  // Elements: shared with the original tuple.
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(copy).CreateFlattenedSet(),
      {constant1, constant2, copy});
}
330 
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
  // Send forwards its operand to its output tuple at {0}. SendDone's
  // points-to set contains only its own buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* payload = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& send_set = points_to_analysis_->GetPointsToSet(send);
  const PointsToSet& send_done_set =
      points_to_analysis_->GetPointsToSet(send_done);

  EXPECT_FALSE(send_set.IsAmbiguous());
  EXPECT_TRUE(send_set.IsDistinct());
  EXPECT_FALSE(send_done_set.IsAmbiguous());
  EXPECT_TRUE(send_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(send_set.element({}), {send});
  ExpectHasTopLevelBuffers(send_set.element({0}), {payload});
  ExpectHasTopLevelBuffers(send_done_set.CreateFlattenedSet(), {send_done});
  ExpectHasBufferAliases(payload, {}, {{payload, {}}, {send, {0}}});
}
357 
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
  // RecvDone forwards its operand's tuple element at {0} to its output. The
  // buffer Recv defines at {0} is therefore also visible through RecvDone at
  // {0}.
  HloComputation::Builder builder(TestName());
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  HloInstruction* recv = builder.AddInstruction(
      HloInstruction::CreateRecv(data_shape, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

  BuildModuleAndRunAnalysis(builder.Build());

  // Both instructions have unambiguous, distinct points-to sets.
  for (const HloInstruction* instruction : {recv, recv_done}) {
    const PointsToSet& points_to =
        points_to_analysis_->GetPointsToSet(instruction);
    EXPECT_FALSE(points_to.IsAmbiguous());
    EXPECT_TRUE(points_to.IsDistinct());
  }

  const PointsToSet& recv_set = points_to_analysis_->GetPointsToSet(recv);
  ExpectHasTopLevelBuffers(recv_set.element({}), {recv});
  ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
377 
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
  // TupleSelect between (constant1, constant2) and (constant2, constant2). The
  // select's element buffers are the union of both operands' element buffers,
  // so element {0} may be either constant and the set is ambiguous.
  HloComputation::Builder builder(TestName());
  auto add_f32_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* constant1 = add_f32_constant(1.0f);
  HloInstruction* constant2 = add_f32_constant(2.0f);
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant2}));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, pred, on_true, on_false));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, select_set.size());
  EXPECT_TRUE(select_set.IsAmbiguous());
  EXPECT_FALSE(select_set.IsDistinct());
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {constant1, constant2});
  ExpectHasTopLevelBuffers(select_set.element({1}), {constant2});
  ExpectHasTopLevelBuffers(select_set.CreateFlattenedSet(),
                           {constant1, constant2, select});
}
408 
TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
  // Create a TupleSelect which selects between two tuple parameters, and a
  // kCopy of its result. Verify the points-to sets of the parameters, the
  // select, and the copy.
  Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})});

  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));

  BuildModuleAndRunAnalysis(builder.Build());

  // At each index, a tuple parameter's points-to set holds exactly the buffer
  // that the parameter itself defines at that index.
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({}),
                   {GetBuffer(param0, {})});
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({0}),
                   {GetBuffer(param0, {0})});
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({1}),
                   {GetBuffer(param0, {1})});

  // At each element index, the select's points-to set is the union of the
  // corresponding element buffers of param0 and param1. The top-level buffer,
  // however, does not alias either parameter: it is defined by the select
  // instruction itself.
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({}),
                   {GetBuffer(select, {})});
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});

  // The copy's points-to set matches the select's except at the top level,
  // where the copy defines its own buffer.
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({}),
                   {GetBuffer(copy, {})});
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
}
456 
TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
  // Both TupleSelect operands are tuples of the same two constants in the same
  // order. Each element index therefore resolves to a single buffer, and the
  // select's points-to set is unambiguous and distinct.
  HloComputation::Builder builder(TestName());
  auto add_f32_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* constant1 = add_f32_constant(1.0f);
  HloInstruction* constant2 = add_f32_constant(2.0f);
  // Two separate, structurally identical (constant1, constant2) tuples.
  auto add_pair = [&builder, constant1, constant2]() {
    return builder.AddInstruction(
        HloInstruction::CreateTuple({constant1, constant2}));
  };
  HloInstruction* lhs = add_pair();
  HloInstruction* rhs = add_pair();
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      lhs->shape(), HloOpcode::kTupleSelect, pred, lhs, rhs));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, select_set.size());
  EXPECT_FALSE(select_set.IsAmbiguous());
  EXPECT_TRUE(select_set.IsDistinct());
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(select_set.element({1}), {constant2});
  ExpectHasTopLevelBuffers(select_set.CreateFlattenedSet(),
                           {constant1, constant2, select});
}
486 
TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
  // Select between ((constant1, constant2)) and ((constant2, constant2)).
  // Verify that the nested points-to sets and tuple sources contain the union
  // of both operands at every index.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto inner_tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto inner_tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant2}));

  auto tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1}));
  auto tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  BuildModuleAndRunAnalysis(builder.Build());

  auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(5, points_to_set.size());
  EXPECT_TRUE(points_to_set.IsAmbiguous());
  EXPECT_FALSE(points_to_set.IsDistinct());

  // Verify points-to set.
  ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
  ExpectHasTopLevelBuffers(points_to_set.element({0}),
                           {inner_tuple1, inner_tuple2});
  ExpectHasTopLevelBuffers(points_to_set.element({0, 0}),
                           {constant1, constant2});
  ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2});

  // Verify tuple sources. Leaf (non-tuple) indices have no tuple sources.
  EXPECT_THAT(points_to_set.tuple_sources({}),
              UnorderedElementsAre(tuple1, tuple2));
  EXPECT_THAT(points_to_set.tuple_sources({0}),
              UnorderedElementsAre(inner_tuple1, inner_tuple2));
  EXPECT_TRUE(points_to_set.tuple_sources({0, 0}).empty());
  EXPECT_TRUE(points_to_set.tuple_sources({0, 1}).empty());
}
533 
TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
  // A bitcast defines no buffer of its own; it points to its operand's buffer.
  // A tuple holding the bitcast therefore sees the bitcast's operand
  // (constant2) at that element index.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
      constant2->shape(), HloOpcode::kBitcast, constant2));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& bitcast_set =
      points_to_analysis_->GetPointsToSet(bitcast);
  EXPECT_EQ(1, bitcast_set.size());
  ExpectHasTopLevelBuffers(bitcast_set.element({}), {constant2});
  EXPECT_TRUE(bitcast_set.tuple_sources({}).empty());

  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {constant2});
}
570 
TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
  // Construct a tuple constant and kCopy it. Verify that the copy defines its
  // own top-level buffer, and that its points-to set correctly points into the
  // nested element buffers of the constant.
  auto builder = HloComputation::Builder(TestName());
  Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
                        LiteralUtil::CreateR1<float>({2.0, 42})};
  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
  auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
      tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));

  BuildModuleAndRunAnalysis(builder.Build());

  auto& points_to_set = points_to_analysis_->GetPointsToSet(copy);

  ExpectHasBuffers(points_to_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(points_to_set.element({0}),
                   {GetBuffer(tuple_constant, {0})});
  ExpectHasBuffers(points_to_set.element({1}),
                   {GetBuffer(tuple_constant, {1})});
}
592 
TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
  // Build ((constant1, constant2), constant2), so constant2 is reachable along
  // two different paths. Each buffer's alias set must list every
  // (instruction, index) pair at which that buffer can be observed.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, constant2}));

  BuildModuleAndRunAnalysis(builder.Build());

  // Outermost first: the top-level tuple is only visible as itself.
  ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}});
  ExpectHasBufferAliases(inner_tuple, /*index=*/{},
                         {{inner_tuple, {}}, {tuple, {0}}});
  ExpectHasBufferAliases(
      constant1, /*index=*/{},
      {{constant1, {}}, {inner_tuple, {0}}, {tuple, {0, 0}}});
  // constant2 appears twice inside the outer tuple: at {0, 1} and at {1}.
  ExpectHasBufferAliases(
      constant2, /*index=*/{},
      {{constant2, {}}, {inner_tuple, {1}}, {tuple, {0, 1}}, {tuple, {1}}});
}
618 
619 class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
620  protected:
621   // Builds a computation, runs instruction fusion HloPass, runs points-to
622   // analysis, then checks for expected results (see unit test cases for
623   // example computation graphs).
Run(const bool add_additional_gte0_user)624   void Run(const bool add_additional_gte0_user) {
625     Shape input_shape = ShapeUtil::MakeShape(F32, {8});
626     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
627     Shape starts_shape = ShapeUtil::MakeShape(S32, {});
628     Shape tuple_shape =
629         ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape});
630 
631     auto builder = HloComputation::Builder(TestName());
632     // Create tuple-shaped parameter.
633     auto tuple_param0 = builder.AddInstruction(
634         HloInstruction::CreateParameter(0, tuple_shape, "param0"));
635     // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1).
636     auto tuple_element1 = builder.AddInstruction(
637         HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
638     auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
639         LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
640     // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
641     auto update = builder.AddInstruction(HloInstruction::CreateBinary(
642         update_shape, HloOpcode::kAdd, tuple_element1, ones));
643     // Create 'input' = GetTupleElement(tuple_param0, 0).
644     auto input = builder.AddInstruction(
645         HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0));
646 
647     if (add_additional_gte0_user) {
648       // Create 'slice' as an additional user of 'input'.
649       auto slice = builder.AddInstruction(
650           HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
651       // Modify 'update' to take 'slice' output.
652       update = builder.AddInstruction(HloInstruction::CreateBinary(
653           update_shape, HloOpcode::kAdd, update, slice));
654     }
655 
656     // Create slice 'starts' = GetTupleElement(tuple_param0, 2).
657     auto starts = builder.AddInstruction(
658         HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2));
659     // Update 'input' with 'update' at dynamic 'starts' indices.
660     builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
661         input_shape, input, update, {starts}));
662 
663     // Build computation and add it to module as entry computation.
664     BuildModule(builder.Build());
665     // Run instruction fusion HloPass.
666     EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
667                     .Run(module_.get())
668                     .ValueOrDie());
669     // Get computation root instruction (should be a kFusion).
670     auto* fusion = module_->entry_computation()->root_instruction();
671     EXPECT_THAT(fusion, op::Fusion(tuple_param0));
672     // Run points-to analysis (should include fused instructions from 'fusion').
673     RunAnalysis();
674 
675     // Check points-to set of fusion parameter associated with 'tuple_param0'.
676     auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0);
677     ExpectHasBuffers(
678         points_to_analysis_->GetPointsToSet(fusion_param).element({}),
679         {GetBuffer(fusion_param, {})});
680     ExpectHasBuffers(
681         points_to_analysis_->GetPointsToSet(fusion_param).element({0}),
682         {GetBuffer(fusion_param, {0})});
683     ExpectHasBuffers(
684         points_to_analysis_->GetPointsToSet(fusion_param).element({1}),
685         {GetBuffer(fusion_param, {1})});
686     ExpectHasBuffers(
687         points_to_analysis_->GetPointsToSet(fusion_param).element({2}),
688         {GetBuffer(fusion_param, {2})});
689 
690     // Check that Gte at tuple_index = 0 points-to fusion_param({0})
691     auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0);
692     ExpectHasBuffers(
693         points_to_analysis_->GetPointsToSet(fused_gte0).element({}),
694         {GetBuffer(fusion_param, {0})});
695     // Check that Gte at tuple_index = 1 points-to fusion_param({1})
696     auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1);
697     ExpectHasBuffers(
698         points_to_analysis_->GetPointsToSet(fused_gte1).element({}),
699         {GetBuffer(fusion_param, {1})});
700     // Check that Gte at tuple_index = 2 points-to fusion_param({2})
701     auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2);
702     ExpectHasBuffers(
703         points_to_analysis_->GetPointsToSet(fused_gte2).element({}),
704         {GetBuffer(fusion_param, {2})});
705 
706     // Check buffer aliases of 'fusion_param' at shape index {0}.
707     ExpectHasBufferAliases(fusion_param, /*index=*/{0},
708                            {{fusion_param, {0}}, {fused_gte0, {}}});
709     // Check buffer aliases of 'fusion_param' at shape index {1}.
710     ExpectHasBufferAliases(fusion_param, /*index=*/{1},
711                            {{fusion_param, {1}}, {fused_gte1, {}}});
712     // Check buffer aliases of 'fusion_param' at shape index {2}.
713     ExpectHasBufferAliases(fusion_param, /*index=*/{2},
714                            {{fusion_param, {2}}, {fused_gte2, {}}});
715 
716     // Check number of users of 'fusion_param' aliases at shape index {0}.
717     ExpectNumUsersOfAliases(fusion_param, {0},
718                             add_additional_gte0_user ? 2 : 1);
719   }
720 
721   // Returns fusion parameter (from 'fusion.fused_instructions') corresponding
722   // to fusion 'operand'.
GetFusionParameterForOperand(HloInstruction * fusion,HloInstruction * operand)723   HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
724                                                HloInstruction* operand) {
725     auto it = absl::c_find_if(
726         fusion->fused_instructions(), [&](const HloInstruction* fused) {
727           return fused->opcode() == HloOpcode::kParameter &&
728                  fusion->operand(fused->parameter_number()) == operand;
729         });
730     CHECK(it != fusion->fused_instructions().end());
731     return *it;
732   }
733 
734   // Returns all users of 'fusion_paran' at 'tuple_index'.
GetFusionParameterUsersAt(HloInstruction * fusion_param,int64 tuple_index)735   std::vector<HloInstruction*> GetFusionParameterUsersAt(
736       HloInstruction* fusion_param, int64 tuple_index) {
737     CHECK(fusion_param->shape().IsTuple());
738     std::vector<HloInstruction*> users_at_tuple_index;
739     for (auto user : fusion_param->users()) {
740       CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode());
741       if (user->tuple_index() == tuple_index) {
742         users_at_tuple_index.push_back(user);
743       }
744     }
745     return users_at_tuple_index;
746   }
747 
748   // Returns the unique user of 'fusion_param' at 'tuple_index'.
GetUniqueFusionParameterUserAt(HloInstruction * fusion_param,int64 tuple_index)749   HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param,
750                                                  int64 tuple_index) {
751     std::vector<HloInstruction*> users =
752         GetFusionParameterUsersAt(fusion_param, tuple_index);
753     CHECK_EQ(1, users.size());
754     return users[0];
755   }
756 
757   // Checks that the count of all users of all aliases of 'instruction' at
758   // 'index' match 'expected_num_users'.
ExpectNumUsersOfAliases(const HloInstruction * instruction,const ShapeIndex & index,const int64 expected_num_users)759   void ExpectNumUsersOfAliases(const HloInstruction* instruction,
760                                const ShapeIndex& index,
761                                const int64 expected_num_users) {
762     const auto* buffer = GetBuffer(instruction, index);
763     int64 num_users = 0;
764     for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) {
765       for (auto user : alias.instruction()->users()) {
766         if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
767           // Gte instructions only access the top-level buffer of their operand.
768           continue;
769         }
770         ++num_users;
771       }
772     }
773     EXPECT_EQ(expected_num_users, num_users);
774   }
775 };
776 
777 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
778 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
779 // Tests that there is a single user of the aliases of tuple-shaped fusion
780 // parameter 0 at shape index {0}.
781 //
782 //             Param0    Const
783 //                 \      /
784 //                  Fusion
785 //                 /      \
786 //        FusionParam0   FusionParam1
787 //        /     |    \       |
788 //     Gte(0) Gte(2) Gte(1)  /
789 //        \     |      \    /
790 //         \    |       Add
791 //          \   |        /
792 //           \0 |2      /1
793 //          DynamicUpdateSlice  // fused root.
794 //
TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) {
  // Without the extra Slice, the aliases of fusion parameter 0 at {0} are
  // expected to have exactly one user.
  constexpr bool kAddAdditionalGte0User = false;
  Run(kAddAdditionalGte0User);
}
798 
799 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
800 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
801 // Tests that there are two users of the aliases of tuple-shaped fusion
802 // parameter 0 at shape index {0}.
803 //
804 //             Param0    Const
805 //                 \      /
806 //                  Fusion
807 //                 /      \
808 //        FusionParam0   FusionParam1
809 //        /     |    \       |
810 //     Gte(2) Gte(0) Gte(1)  /
811 //        \     |      \    /
812 //         \    |\      Add
813 //          \   | \      /
814 //           |  | Slice /
815 //           |  |   \  /
816 //           |  |   Add
817 //           |  |    |
818 //           |2 |0   |1
819 //          DynamicUpdateSlice  // fused root.
820 //
TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) {
  // The extra Slice of tuple element 0 gives the aliases of fusion parameter 0
  // at {0} a second user.
  constexpr bool kAddAdditionalGte0User = true;
  Run(kAddAdditionalGte0User);
}
824 
825 class PointsToAnalysisTestBase : public HloTestBase {
826  protected:
BuildModule(std::unique_ptr<HloComputation> computation)827   void BuildModule(std::unique_ptr<HloComputation> computation) {
828     module_ = CreateNewUnverifiedModule();
829     computation_ = module_->AddEntryComputation(std::move(computation));
830   }
831 
RunAnalysis()832   void RunAnalysis() {
833     CHECK_NOTNULL(module_.get());
834     points_to_analysis_ =
835         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
836   }
837 
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)838   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
839     BuildModule(std::move(computation));
840     RunAnalysis();
841   }
842 
843   std::unique_ptr<HloModule> module_;
844   HloComputation* computation_ = nullptr;
845   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
846 };
847 
848 class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
849 
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({vec_shape, vec_shape});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* pair = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, pair, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, pair, 1));
  builder.AddInstruction(
      HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, first, second));

  BuildModuleAndRunAnalysis(builder.Build());

  // A GetTupleElement reads only the top-level (tuple) buffer of its operand,
  // not the element buffers it forwards.
  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(analysis.DoesNotUseOperandBuffer(pair, {0}, first));
  EXPECT_TRUE(analysis.DoesNotUseOperandBuffer(pair, {1}, second));
  EXPECT_FALSE(analysis.DoesNotUseOperandBuffer(pair, {}, first));
  EXPECT_FALSE(analysis.DoesNotUseOperandBuffer(pair, {}, second));
}
872 
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({vec_shape, vec_shape});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* pair = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, pair, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, pair, 1));

  // Overwrite part of tuple element 1 with three 2.0s, starting at index 2.
  HloInstruction* offset = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* patch = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vec_shape, elem1, patch, {offset}));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  BuildModule(builder.Build());
  // Fusing 'elem1' as well makes the tuple parameter an operand of the fusion.
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, offset, patch, elem1}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // Element {0} is never read by the fusion; element {1} is.
  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(analysis.DoesNotUseOperandBuffer(pair, {0}, fusion));
  EXPECT_FALSE(analysis.DoesNotUseOperandBuffer(pair, {1}, fusion));
}
906 
907 class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
908 
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  // param -> exp -> log, all f32[8]: each elementwise user may reuse the
  // buffer of its same-shaped operand.
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* logarithm = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kLog, exponential));

  BuildModuleAndRunAnalysis(builder.Build());

  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(input, {}, exponential, {}));
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(exponential, {}, logarithm, {}));
}
927 
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  const Shape f32_vec = ShapeUtil::MakeShape(F32, {8});
  const Shape pred_vec = ShapeUtil::MakeShape(PRED, {8});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_vec, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_vec, "param1"));
  HloInstruction* eq = builder.AddInstruction(HloInstruction::CreateCompare(
      pred_vec, lhs, rhs, ComparisonDirection::kEq));

  BuildModuleAndRunAnalysis(builder.Build());

  // Neither f32 input may share its buffer with the PRED-typed result.
  auto& analysis = *points_to_analysis_;
  for (HloInstruction* input : {lhs, rhs}) {
    EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(input, {}, eq, {}));
  }
}
947 
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  // param -> exp -> copy, all f32[8]: both the exp and the copy may reuse
  // their operand's buffer.
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* copied = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kCopy, exponential));

  BuildModuleAndRunAnalysis(builder.Build());

  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(input, {}, exponential, {}));
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(exponential, {}, copied, {}));
}
966 
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({vec_shape, vec_shape});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* pair = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, pair, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, pair, 1));

  // Overwrite part of tuple element 1 with three 2.0s, starting at index 2.
  HloInstruction* offset = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* patch = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vec_shape, elem1, patch, {offset}));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, offset, patch, elem1}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // Only tuple element {1} feeds the fused DynamicUpdateSlice, so only it may
  // share with the fusion output.
  auto& analysis = *points_to_analysis_;
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(pair, {0}, fusion, {}));
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(pair, {1}, fusion, {}));
}
1001 
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape update_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape starts_shape = ShapeUtil::MakeShape(S32, {});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* target = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* patch = builder.AddInstruction(
      HloInstruction::CreateParameter(1, update_shape, "update"));
  HloInstruction* offset = builder.AddInstruction(
      HloInstruction::CreateParameter(2, starts_shape, "starts"));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, target, patch, {offset}));

  BuildModuleAndRunAnalysis(builder.Build());

  // Only the operand being updated (operand 0) may share with the result;
  // the update and start-index operands may not.
  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(target, {}, dus, {}));
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(patch, {}, dus, {}));
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(offset, {}, dus, {}));
}
1028 
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  const char* const kModuleText = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(kModuleText));
  computation_ = module_->entry_computation();
  RunAnalysis();

  HloInstruction* const scatter = computation_->root_instruction();
  HloInstruction* const operand = computation_->parameter_instruction(0);
  HloInstruction* const indices = computation_->parameter_instruction(1);
  HloInstruction* const updates = computation_->parameter_instruction(2);

  // Only the scattered-into operand may share its buffer with the scatter.
  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(operand, {}, scatter, {}));
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(indices, {}, scatter, {}));
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(updates, {}, scatter, {}));
}
1066 
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  auto builder = HloComputation::Builder(TestName());
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction* sorted,
      MakeSortHlo(keys_shape, {keys_param}, 0, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  // A keys-only sort may sort in place.
  EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(
      keys_param, {}, sorted, {}));
}
1084 
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  auto builder = HloComputation::Builder(TestName());
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  const Shape sorted_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape});
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction* sorted,
      MakeSortHlo(sorted_shape, {keys_param, values_param}, 0,
                  /*is_stable=*/false, &builder, module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  auto& analysis = *points_to_analysis_;
  // Each input may share only with the tuple entry at its own position:
  // keys with {0}, values with {1}.
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(keys_param, {}, sorted, {0}));
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(values_param, {}, sorted, {1}));
  // Swapped positions must not be reported as shareable.
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(keys_param, {}, sorted, {1}));
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(values_param, {}, sorted, {0}));
}
1116 
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      /*new_size=*/2, PrecisionConfig::DEFAULT);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar (rank 0), so the broadcast's dimension list, which maps
  // each operand dimension to an output dimension, must be empty.
  auto add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, dot}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused dot add should be able to share buffer with 'add_operand'.
  EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(
      add_operand, {}, fusion, {}));
}
1152 
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar (rank 0), so the broadcast takes an empty dimension list.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));

  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
  EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                  fusion, {}));
}
1180 
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});

  module_ = CreateNewUnverifiedModule();

  // Condition computation: compares its parameter with itself.
  // NOTE(review): an f32[8] compare is given a PRED[] result shape here; the
  // module is created unverified, so presumably nothing checks this.
  HloComputation* cond_computation = nullptr;
  {
    auto cond_builder = HloComputation::Builder(TestName() + ".Cond");
    HloInstruction* cond_param = cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    cond_builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), cond_param, cond_param,
        ComparisonDirection::kEq));
    cond_computation = module_->AddEmbeddedComputation(cond_builder.Build());
  }

  // Body computation: returns data + data.
  HloComputation* body_computation = nullptr;
  {
    auto body_builder = HloComputation::Builder(TestName() + ".Body");
    HloInstruction* body_param = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    body_builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape, HloOpcode::kAdd, body_param, body_param));
    body_computation = module_->AddEmbeddedComputation(body_builder.Build());
  }

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* loop = builder.AddInstruction(HloInstruction::CreateWhile(
      data_shape, cond_computation, body_computation, init));
  computation_ = module_->AddEntryComputation(builder.Build());

  RunAnalysis();

  // The While instruction may reuse the buffer of its sole operand.
  EXPECT_TRUE(
      points_to_analysis_->CanShareOperandBufferWithUser(init, {}, loop, {}));
}
1221 
1222 // Tests that Call can alias operand buffer if the only use of the operand
1223 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
  Shape shape = ShapeUtil::MakeShape(F32, {8});
  // Build sub-computation with fusion root.
  auto sub_builder = HloComputation::Builder(TestName() + "_sub");
  auto sub_param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "sub_param"));
  auto one = sub_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a scalar (rank 0), so the broadcast takes an empty dimension list.
  auto ones = sub_builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape, one, {}));
  auto add = sub_builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));

  module_ = CreateNewUnverifiedModule();
  auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
  sub_computation->CreateFusionInstruction({add, ones},
                                           HloInstruction::FusionKind::kLoop);

  // Build entry-computation with kCall which calls 'sub_computation'.
  auto builder = HloComputation::Builder(TestName());

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  auto reverse =
      builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
  auto call = builder.AddInstruction(
      HloInstruction::CreateCall(shape, {reverse}, sub_computation));
  computation_ = module_->AddEntryComputation(builder.Build());

  RunAnalysis();

  EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {},
                                                                 call, {}));
}
1258 
TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {16, 32});
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {16});

  auto builder = HloComputation::Builder(TestName() + "_fusion");
  HloInstruction* full_input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "full"));
  HloInstruction* narrow_input = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vec_shape, "small"));
  HloInstruction* widened = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, narrow_input, {0}));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kAdd, full_input, widened));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {sum, widened}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // The same-shaped f32[16,32] input may share with the fusion output; the
  // f32[16] input, which reaches the add only through a broadcast, may not.
  auto& analysis = *points_to_analysis_;
  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(full_input, {}, fusion, {}));
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(narrow_input, {}, fusion, {}));
}
1283 
1284 }  // namespace
1285 }  // namespace xla
1286